{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import numpy as np \n",
    "import seaborn as sns \n",
    "import matplotlib.pyplot as plt "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Titanic train/test splits from CSV files in the working directory.\n",
    "train, test = (\n",
    "    pd.read_csv(csv_path)\n",
    "    for csv_path in ('titanic_train.csv', 'titanic_test.csv')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PassengerId</th>\n",
       "      <th>Survived</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Fare</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>891.000000</td>\n",
       "      <td>891.000000</td>\n",
       "      <td>891.000000</td>\n",
       "      <td>714.000000</td>\n",
       "      <td>891.000000</td>\n",
       "      <td>891.000000</td>\n",
       "      <td>891.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>446.000000</td>\n",
       "      <td>0.383838</td>\n",
       "      <td>2.308642</td>\n",
       "      <td>29.699118</td>\n",
       "      <td>0.523008</td>\n",
       "      <td>0.381594</td>\n",
       "      <td>32.204208</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>257.353842</td>\n",
       "      <td>0.486592</td>\n",
       "      <td>0.836071</td>\n",
       "      <td>14.526497</td>\n",
       "      <td>1.102743</td>\n",
       "      <td>0.806057</td>\n",
       "      <td>49.693429</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.420000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>223.500000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>20.125000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>7.910400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>446.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>28.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>14.454200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>668.500000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>38.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>31.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>891.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>80.000000</td>\n",
       "      <td>8.000000</td>\n",
       "      <td>6.000000</td>\n",
       "      <td>512.329200</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       PassengerId    Survived      Pclass         Age       SibSp  \\\n",
       "count   891.000000  891.000000  891.000000  714.000000  891.000000   \n",
       "mean    446.000000    0.383838    2.308642   29.699118    0.523008   \n",
       "std     257.353842    0.486592    0.836071   14.526497    1.102743   \n",
       "min       1.000000    0.000000    1.000000    0.420000    0.000000   \n",
       "25%     223.500000    0.000000    2.000000   20.125000    0.000000   \n",
       "50%     446.000000    0.000000    3.000000   28.000000    0.000000   \n",
       "75%     668.500000    1.000000    3.000000   38.000000    1.000000   \n",
       "max     891.000000    1.000000    3.000000   80.000000    8.000000   \n",
       "\n",
       "            Parch        Fare  \n",
       "count  891.000000  891.000000  \n",
       "mean     0.381594   32.204208  \n",
       "std      0.806057   49.693429  \n",
       "min      0.000000    0.000000  \n",
       "25%      0.000000    7.910400  \n",
       "50%      0.000000   14.454200  \n",
       "75%      0.000000   31.000000  \n",
       "max      6.000000  512.329200  "
      ]
     },
     "execution_count": 119,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summary statistics for the numeric columns of the training split.\n",
    "train.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack train and test so the feature engineering below is applied to both at once.\n",
    "# The index is not reset, so test rows reuse the labels 0..417.\n",
    "df = pd.concat([train, test], axis=0).drop(columns='PassengerId')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "Int64Index: 1309 entries, 0 to 417\n",
      "Data columns (total 11 columns):\n",
      " #   Column    Non-Null Count  Dtype  \n",
      "---  ------    --------------  -----  \n",
      " 0   Survived  891 non-null    float64\n",
      " 1   Pclass    1309 non-null   int64  \n",
      " 2   Name      1309 non-null   object \n",
      " 3   Sex       1309 non-null   object \n",
      " 4   Age       1046 non-null   float64\n",
      " 5   SibSp     1309 non-null   int64  \n",
      " 6   Parch     1309 non-null   int64  \n",
      " 7   Ticket    1309 non-null   object \n",
      " 8   Fare      1308 non-null   float64\n",
      " 9   Cabin     295 non-null    object \n",
      " 10  Embarked  1307 non-null   object \n",
      "dtypes: float64(3), int64(3), object(5)\n",
      "memory usage: 122.7+ KB\n"
     ]
    }
   ],
   "source": [
    "# Column dtypes and non-null counts of the combined frame.\n",
    "df.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Survived       1.000000\n",
       "Fare           0.257307\n",
       "Parch          0.081629\n",
       "PassengerId   -0.005007\n",
       "SibSp         -0.035322\n",
       "Age           -0.077221\n",
       "Pclass        -0.338481\n",
       "Name: Survived, dtype: float64"
      ]
     },
     "execution_count": 122,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pearson correlation of every numeric column with the target, strongest first.\n",
    "# Non-numeric columns (Name, Sex, ...) are excluded explicitly because\n",
    "# DataFrame.corr() raises on them in pandas >= 2.0.\n",
    "train.select_dtypes(include='number').corr()['Survived'].sort_values(ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Impute missing ages with the mean age, then compress the scale with log(1 + x).\n",
    "age_filled = df['Age'].fillna(df['Age'].mean())\n",
    "df['Age'] = np.log1p(age_filled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Impute missing fares with the mean fare, then compress the scale with log(1 + x).\n",
    "fare_filled = df['Fare'].fillna(df['Fare'].mean())\n",
    "df['Fare'] = np.log1p(fare_filled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the first character of the cabin code; missing cabins become 'N'\n",
    "# (first character of the 'None' placeholder). The raw column is then dropped.\n",
    "df['cabin_captial'] = df['Cabin'].fillna('None').str[:1]\n",
    "df.drop(columns=['Cabin'], inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill missing embarkation values with the most frequent one.\n",
    "most_common_port = df['Embarked'].mode()[0]\n",
    "df['Embarked'] = df['Embarked'].fillna(most_common_port)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the Name and Ticket columns before the frame is encoded.\n",
    "df.drop(columns=['Name', 'Ticket'], inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store passenger class as a string so it can be concatenated with the cabin\n",
    "# letter and one-hot encoded later on.\n",
    "df['Pclass'] = df['Pclass'].astype(str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "metadata": {},
   "outputs": [],
   "source": [
    "# family_size is the sum of the SibSp and Parch counts.\n",
    "df['family_size'] = df['SibSp'].add(df['Parch'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combined feature: passenger class followed by cabin letter, e.g. '3N' or '1C'.\n",
    "df['pc'] = df['Pclass'].str.cat(df['cabin_captial'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Survival rate by sex, passenger class and cabin letter.\n",
    "# The cabin letter is derived inline so this cell does not depend on a\n",
    "# 'cabin_cap' column that is only added to `train` further down.\n",
    "(\n",
    "    train\n",
    "    .assign(cabin_cap=train['Cabin'].fillna('None').str[:1])\n",
    "    .groupby(['Sex', 'Pclass', 'cabin_cap'])['Survived']\n",
    "    .mean()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of labelled passengers in each sex / passenger class / cabin letter group.\n",
    "# The cabin letter is derived inline so this cell does not depend on a\n",
    "# 'cabin_cap' column that is only added to `train` further down.\n",
    "(\n",
    "    train\n",
    "    .assign(cabin_cap=train['Cabin'].fillna('None').str[:1])\n",
    "    .groupby(['Sex', 'Pclass', 'cabin_cap'])['Survived']\n",
    "    .count()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 134,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Survived</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Fare</th>\n",
       "      <th>Embarked</th>\n",
       "      <th>cabin_captial</th>\n",
       "      <th>family_size</th>\n",
       "      <th>pc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>3.135494</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>2.110213</td>\n",
       "      <td>S</td>\n",
       "      <td>N</td>\n",
       "      <td>1</td>\n",
       "      <td>3N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>3.663562</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>4.280593</td>\n",
       "      <td>C</td>\n",
       "      <td>C</td>\n",
       "      <td>1</td>\n",
       "      <td>1C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.0</td>\n",
       "      <td>3</td>\n",
       "      <td>female</td>\n",
       "      <td>3.295837</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2.188856</td>\n",
       "      <td>S</td>\n",
       "      <td>N</td>\n",
       "      <td>0</td>\n",
       "      <td>3N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.0</td>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>3.583519</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>3.990834</td>\n",
       "      <td>S</td>\n",
       "      <td>C</td>\n",
       "      <td>1</td>\n",
       "      <td>1C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.0</td>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>3.583519</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2.202765</td>\n",
       "      <td>S</td>\n",
       "      <td>N</td>\n",
       "      <td>0</td>\n",
       "      <td>3N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>413</th>\n",
       "      <td>NaN</td>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>3.430146</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2.202765</td>\n",
       "      <td>S</td>\n",
       "      <td>N</td>\n",
       "      <td>0</td>\n",
       "      <td>3N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>414</th>\n",
       "      <td>NaN</td>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>3.688879</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>4.699571</td>\n",
       "      <td>C</td>\n",
       "      <td>C</td>\n",
       "      <td>0</td>\n",
       "      <td>1C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>415</th>\n",
       "      <td>NaN</td>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>3.676301</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2.110213</td>\n",
       "      <td>S</td>\n",
       "      <td>N</td>\n",
       "      <td>0</td>\n",
       "      <td>3N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>416</th>\n",
       "      <td>NaN</td>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>3.430146</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2.202765</td>\n",
       "      <td>S</td>\n",
       "      <td>N</td>\n",
       "      <td>0</td>\n",
       "      <td>3N</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>417</th>\n",
       "      <td>NaN</td>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>3.430146</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>3.150952</td>\n",
       "      <td>C</td>\n",
       "      <td>N</td>\n",
       "      <td>2</td>\n",
       "      <td>3N</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1309 rows × 11 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     Survived Pclass     Sex       Age  SibSp  Parch      Fare Embarked  \\\n",
       "0         0.0      3    male  3.135494      1      0  2.110213        S   \n",
       "1         1.0      1  female  3.663562      1      0  4.280593        C   \n",
       "2         1.0      3  female  3.295837      0      0  2.188856        S   \n",
       "3         1.0      1  female  3.583519      1      0  3.990834        S   \n",
       "4         0.0      3    male  3.583519      0      0  2.202765        S   \n",
       "..        ...    ...     ...       ...    ...    ...       ...      ...   \n",
       "413       NaN      3    male  3.430146      0      0  2.202765        S   \n",
       "414       NaN      1  female  3.688879      0      0  4.699571        C   \n",
       "415       NaN      3    male  3.676301      0      0  2.110213        S   \n",
       "416       NaN      3    male  3.430146      0      0  2.202765        S   \n",
       "417       NaN      3    male  3.430146      1      1  3.150952        C   \n",
       "\n",
       "    cabin_captial  family_size  pc  \n",
       "0               N            1  3N  \n",
       "1               C            1  1C  \n",
       "2               N            0  3N  \n",
       "3               C            1  1C  \n",
       "4               N            0  3N  \n",
       "..            ...          ...  ..  \n",
       "413             N            0  3N  \n",
       "414             C            0  1C  \n",
       "415             N            0  3N  \n",
       "416             N            0  3N  \n",
       "417             N            2  3N  \n",
       "\n",
       "[1309 rows x 11 columns]"
      ]
     },
     "execution_count": 134,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the engineered frame (train rows followed by test rows).\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Number of survivors among male passengers outside first class (training rows only).\n",
    "# A boolean mask is used instead of feeding index labels to .iloc: the combined\n",
    "# frame has duplicate labels, so labels coming from test rows would select\n",
    "# unrelated training rows by position.\n",
    "train_part = df.iloc[:len(train)]\n",
    "is_non_first_class_male = (train_part['Pclass'] != '1') & (train_part['Sex'] == 'male')\n",
    "train_part.loc[is_non_first_class_male, 'Survived'].sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only Cabin needs the 'None' placeholder for the cabin-letter column built next.\n",
    "# Filling the whole frame would also write the string into numeric columns\n",
    "# such as Age and turn them into object dtype.\n",
    "train['Cabin'] = train['Cabin'].fillna('None')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First character of each cabin code ('N' for the 'None' placeholder).\n",
    "train['cabin_cap'] = [cabin[:1] for cabin in train['Cabin']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0      N\n",
       "1      C\n",
       "2      N\n",
       "3      C\n",
       "4      N\n",
       "      ..\n",
       "886    N\n",
       "887    B\n",
       "888    N\n",
       "889    C\n",
       "890    N\n",
       "Name: cabin_cap, Length: 891, dtype: object"
      ]
     },
     "execution_count": 152,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the derived cabin-letter column.\n",
    "train['cabin_cap']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encode the object-typed columns, then undo the train/test stacking by position.\n",
    "n_train = len(train)\n",
    "df_ = pd.get_dummies(df)\n",
    "train_, test_ = df_.iloc[:n_train], df_.iloc[n_train:]\n",
    "\n",
    "# Survived is the target; it is NaN for the test rows and removed from both feature sets.\n",
    "train_y = train_['Survived']\n",
    "train_x, test_x = (part.drop(columns='Survived') for part in (train_, test_))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## modeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import GridSearchCV, KFold, StratifiedKFold, cross_val_score, train_test_split\n",
    "from sklearn.linear_model import Lasso, Ridge, LogisticRegression\n",
    "from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier\n",
    "import lightgbm\n",
    "\n",
    "# Shared 5-fold stratified splitter passed as `cv` to the cross-validation calls below.\n",
    "kfold = StratifiedKFold(n_splits=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 2 concurrent workers.\n",
      "[Parallel(n_jobs=-1)]: Done   5 out of   5 | elapsed:    3.7s finished\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scoring:  0.8249450756386917\n",
      "params:  {'criterion': 'gini', 'max_depth': None, 'max_features': 5, 'min_samples_leaf': 3, 'min_samples_split': 5, 'n_estimators': 100}\n"
     ]
    }
   ],
   "source": [
    "# 随机森林\n",
    "# Random forest: grid search over a single hyper-parameter combination,\n",
    "# scored by accuracy with the shared stratified splitter.\n",
    "# random_state is fixed so the cross-validation score is reproducible.\n",
    "rfc = RandomForestClassifier(random_state=42)\n",
    "\n",
    "rf_param_grid = {\n",
    "    'max_depth': [None],\n",
    "    'max_features': [5],\n",
    "    'min_samples_split': [5],\n",
    "    'min_samples_leaf': [3],\n",
    "    'n_estimators': [100],\n",
    "    'criterion': ['gini']\n",
    "}\n",
    "\n",
    "gs_rfc = GridSearchCV(rfc,\n",
    "                      param_grid=rf_param_grid,\n",
    "                      cv=kfold,\n",
    "                      scoring='accuracy',\n",
    "                      n_jobs=-1,\n",
    "                      verbose=0  # negative values are rejected by recent scikit-learn\n",
    "                     )\n",
    "gs_rfc.fit(train_x, train_y)\n",
    "\n",
    "print('scoring: ', gs_rfc.best_score_)\n",
    "print('params: ', gs_rfc.best_params_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(clf):\n",
    "    \"\"\"Cross-validate ``clf`` on the training data and print its accuracy.\n",
    "\n",
    "    Uses the notebook-level ``train_x``, ``train_y`` and ``kfold`` objects.\n",
    "    Prints the per-fold accuracies, then their mean and standard deviation.\n",
    "    Returns None.\n",
    "    \"\"\"\n",
    "    fold_scores = cross_val_score(\n",
    "        clf, train_x, train_y, scoring='accuracy', cv=kfold, n_jobs=-1\n",
    "    )\n",
    "    print('scoring', fold_scores)\n",
    "    mean_accuracy, accuracy_std = fold_scores.mean(), fold_scores.std()\n",
    "    print('score', mean_accuracy, accuracy_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scoring [0.77094972 0.76966292 0.8258427  0.84269663 0.88764045]\n",
      "score 0.8193584834599209 0.044860223258249614\n"
     ]
    }
   ],
   "source": [
    "# Cross-validate a forest with the grid-search settings but 200 trees instead of 100.\n",
    "# random_state is fixed so repeated runs print the same scores.\n",
    "rfc_ = RandomForestClassifier(max_features=5,\n",
    "                             min_samples_split=5,\n",
    "                             min_samples_leaf=3,\n",
    "                             n_estimators=200,\n",
    "                             criterion='gini',\n",
    "                             random_state=42)\n",
    "evaluate(rfc_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "score:  0.8002699140041429\n",
      "params:  {'C': 0.03}\n"
     ]
    }
   ],
   "source": [
    "# 逻辑回归\n",
    "# Logistic regression: grid search with a single candidate value for C.\n",
    "lr = LogisticRegression()\n",
    "lr_params = {'C': [0.03]}\n",
    "\n",
    "lr_gs = GridSearchCV(\n",
    "    estimator=lr,\n",
    "    param_grid=lr_params,\n",
    "    scoring='accuracy',\n",
    "    cv=kfold,\n",
    "    n_jobs=-1,\n",
    ").fit(train_x, train_y)\n",
    "\n",
    "print('score: ', lr_gs.best_score_)\n",
    "print('params: ', lr_gs.best_params_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scoring [0.75977654 0.80337079 0.80898876 0.79213483 0.83707865]\n",
      "score 0.8002699140041429 0.025085494337963445\n"
     ]
    }
   ],
   "source": [
    "# Cross-validate logistic regression with the C value used in the grid search.\n",
    "lr_ = LogisticRegression(C=0.03)\n",
    "evaluate(lr_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "score:  0.8047768501663424\n",
      "params:  {'learning_rate': 1, 'n_estimators': 50}\n"
     ]
    }
   ],
   "source": [
    "# AdaBoost: grid search over the learning rate and the number of estimators.\n",
    "ada = AdaBoostClassifier()\n",
    "ada_params = {\n",
    "    'learning_rate': [0.3, 1],\n",
    "    'n_estimators': [50, 100],\n",
    "}\n",
    "\n",
    "# GridSearchCV.fit returns the fitted search object, so it can be chained.\n",
    "ada_gs = GridSearchCV(\n",
    "    estimator=ada,\n",
    "    param_grid=ada_params,\n",
    "    scoring='accuracy',\n",
    "    cv=kfold,\n",
    "    n_jobs=-1,\n",
    ").fit(train_x, train_y)\n",
    "\n",
    "print('score: ', ada_gs.best_score_)\n",
    "print('params: ', ada_gs.best_params_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set predictions from the best AdaBoost estimator found by the grid search.\n",
    "ada_best = ada_gs.best_estimator_\n",
    "ada_preds = ada_best.predict(test_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scoring [0.76536313 0.81460674 0.82022472 0.82022472 0.83707865]\n",
      "score 0.8114995919904588 0.02426838117295715\n"
     ]
    }
   ],
   "source": [
    "# AdaBoost with learning_rate=0.3 and 1000 estimators, scored with evaluate().\n",
    "ada_settings = {'n_estimators': 1000, 'learning_rate': 0.3}\n",
    "ada_ = AdaBoostClassifier(**ada_settings)\n",
    "evaluate(ada_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8350574351892537\n",
      "{'bagging_fraction': 0.8, 'cat_smooth': 0, 'feature_feaction': 0.8, 'learning_rate': 0.1, 'max_depth': 5, 'num_leaves': 7, 'objective': 'binary'}\n"
     ]
    }
   ],
   "source": [
    "# LightGBM: single-point parameter grid, scored through GridSearchCV.\n",
    "lgbc = lightgbm.LGBMClassifier()\n",
    "lgb_param = {\n",
    "    'objective': ['binary'],\n",
    "    'max_depth': [5],\n",
    "    'num_leaves': [7],\n",
    "    'learning_rate': [0.1],\n",
    "    # Previously misspelled as 'feature_feaction', which is not a LightGBM\n",
    "    # parameter name, so the intended 0.8 setting never reached feature_fraction.\n",
    "    'feature_fraction': [0.8],\n",
    "    # NOTE(review): per the LightGBM parameter docs, bagging_fraction only takes\n",
    "    # effect when bagging_freq is non-zero -- confirm whether bagging was intended.\n",
    "    'bagging_fraction': [0.8],\n",
    "    'cat_smooth': [0],\n",
    "}\n",
    "lgbm_gs = GridSearchCV(lgbc,\n",
    "                       param_grid=lgb_param,\n",
    "                       cv=kfold,\n",
    "                       scoring='accuracy',\n",
    "                       n_jobs=-1\n",
    "                      )\n",
    "lgbm_gs.fit(train_x, train_y)\n",
    "\n",
    "print(lgbm_gs.best_score_)\n",
    "print(lgbm_gs.best_params_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "scoring [0.79329609 0.8258427  0.84831461 0.81460674 0.86516854]\n",
      "score 0.8294457347310275 0.025187804823704026\n"
     ]
    }
   ],
   "source": [
    "# Manually chosen LightGBM settings, scored with the evaluate() helper.\n",
    "lgbc_settings = {\n",
    "    'objective': 'binary',\n",
    "    'max_depth': 4,\n",
    "    'num_leaves': 6,\n",
    "    'learning_rate': 0.3,\n",
    "    'bagging_fraction': 0.5,\n",
    "    'feature_fraction': 0.6,\n",
    "    'cat_smooth': 0,\n",
    "}\n",
    "lgbc_ = lightgbm.LGBMClassifier(**lgbc_settings)\n",
    "evaluate(lgbc_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set predictions from each tuned model; only the random forest output\n",
    "# is placed in the submission frame below.\n",
    "lr_preds = lr_gs.best_estimator_.predict(test_x)\n",
    "rf_preds = gs_rfc.best_estimator_.predict(test_x)\n",
    "lgbm_preds = lgbm_gs.best_estimator_.predict(test_x)\n",
    "\n",
    "# Two-column frame: passenger ids from the raw test set plus the predictions.\n",
    "sub = pd.DataFrame({\n",
    "    'PassengerId': test.PassengerId,\n",
    "    'Survived': rf_preds,\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the random-forest predictions to the test frame as a 'Survived' column.\n",
    "test = test.assign(Survived=rf_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rule-based override applied to the predictions stored in test['Survived'].\n",
    "# Assign the filled column back rather than calling fillna(inplace=True) on the\n",
    "# test['Cabin'] selection: that chained in-place form is deprecated in pandas 2.x\n",
    "# and does not modify the frame once copy-on-write is enabled.\n",
    "test['Cabin'] = test['Cabin'].fillna('None')\n",
    "# First character of the cabin string; 'N' is the first letter of the 'None' filler.\n",
    "test['cabin_cap'] = test['Cabin'].str[0]\n",
    "# NOTE(review): despite the name, the Pclass != 1 filter keeps every class other\n",
    "# than 1st, not only 3rd class -- confirm which was intended.\n",
    "male_3_N = test[(test['Pclass'] != 1) & (test['Sex'] == 'male') & (test['cabin_cap'] == 'N')].index\n",
    "# Computed but not used by anything in this cell.\n",
    "female_1_2 = test[(test['Pclass'] != 3) & (test['Sex'] == 'female')].index\n",
    "\n",
    "# male_3_N holds index labels, so select rows with .loc, and address the column by\n",
    "# name instead of the hard-coded position 11, which depends on the column layout.\n",
    "test.loc[male_3_N, 'Survived'] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 164,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the random-forest predictions out as the submission CSV.\n",
    "sub['Survived'] = rf_preds\n",
    "# np.int was deprecated in NumPy 1.20 and removed in 1.24; the builtin int is the\n",
    "# documented replacement, and astype casts the whole column in one step.\n",
    "sub['Survived'] = sub['Survived'].astype(int)\n",
    "sub.to_csv('submission_2020_03_17.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 165,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PassengerId</th>\n",
       "      <th>Survived</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>892</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>893</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>894</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>895</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>896</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>413</th>\n",
       "      <td>1305</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>414</th>\n",
       "      <td>1306</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>415</th>\n",
       "      <td>1307</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>416</th>\n",
       "      <td>1308</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>417</th>\n",
       "      <td>1309</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>418 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     PassengerId  Survived\n",
       "0            892         0\n",
       "1            893         0\n",
       "2            894         0\n",
       "3            895         0\n",
       "4            896         0\n",
       "..           ...       ...\n",
       "413         1305         0\n",
       "414         1306         1\n",
       "415         1307         0\n",
       "416         1308         0\n",
       "417         1309         0\n",
       "\n",
       "[418 rows x 2 columns]"
      ]
     },
     "execution_count": 165,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sub"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
